import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gym
import torch.nn.functional as F

# 定义CNN策略网络
class CNNPolicy(nn.Module):
    """CNN policy network mapping image observations to action probabilities.

    Three ReLU conv layers are followed by adaptive average pooling to a 7x7
    map, a 512-unit ReLU fully connected layer and a linear output layer whose
    result is passed through softmax.

    For 84x84 inputs the conv stack yields exactly a 7x7 map, so the pooling
    step is an identity there; other input resolutions are accepted as well
    because the flattened size fed to ``fc1`` is always 64 * 7 * 7.

    Args:
        input_channels: Number of input channels (dim 1 of the NCHW input).
        action_dim: Number of discrete actions (size of the output).
    """

    def __init__(self, input_channels=1, action_dim=4):
        super(CNNPolicy, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        # Parameter-free, so state_dict keys and init RNG order are unchanged;
        # it pins fc1's input size to 64*7*7 regardless of the input H, W.
        self.pool = nn.AdaptiveAvgPool2d((7, 7))
        self.fc1 = nn.Linear(64 * 7 * 7, 512)
        self.fc2 = nn.Linear(512, action_dim)

    def forward(self, x):
        """Return action probabilities of shape (N, action_dim).

        Args:
            x: Float tensor of shape (N, input_channels, H, W).
        """
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        x = self.pool(x)
        x = torch.flatten(x, start_dim=1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return torch.softmax(x, dim=-1)



# Define the PPO actor-critic network (shared CNN trunk with separate policy and value heads)
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np


class PPOAgent(nn.Module):
    """Actor-critic network for PPO: a shared CNN trunk with two linear heads.

    The trunk runs three ReLU conv layers, adaptive-average-pools the feature
    map to 6x6 and applies a 512-unit ReLU fully connected layer. Because of
    the adaptive pooling the flattened size is always 64 * 6 * 6, so the input
    resolution is not tied to 84x84 (it only has to be large enough for the
    conv stack to produce a non-empty map).

    Args:
        obs_channels: Number of input channels (dim 1 of the NCHW input).
        action_dim: Number of discrete actions.
    """

    def __init__(self, obs_channels, action_dim):
        super().__init__()
        # Convolutional feature extractor.
        self.conv1 = nn.Conv2d(obs_channels, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        # Fixed 6x6 spatial output makes the linear layer resolution-independent.
        self.avg = nn.AdaptiveAvgPool2d((6, 6))
        self.flatten_size = 64 * 6 * 6
        self.fc = nn.Linear(self.flatten_size, 512)
        # Independent heads on the shared 512-d features.
        self.actor = nn.Linear(512, action_dim)  # raw (pre-softmax) action logits
        self.critic = nn.Linear(512, 1)  # scalar state-value estimate

    def forward(self, x):
        """Return ``(action_logits, value)`` for a batch of observations.

        Args:
            x: Float tensor of shape (N, obs_channels, H, W).

        Returns:
            Tuple of action logits with shape (N, action_dim) and state values
            with shape (N, 1).
        """
        features = x
        for conv in (self.conv1, self.conv2, self.conv3):
            features = torch.relu(conv(features))
        pooled = torch.flatten(self.avg(features), start_dim=1)
        hidden = torch.relu(self.fc(pooled))
        return self.actor(hidden), self.critic(hidden)


def compute_gae(rewards, values, next_value, dones, gamma=0.99, lam=0.95):
    """Compute Generalized Advantage Estimation (GAE) for one rollout.

    Walks the rollout backwards applying::

        delta_t = r_t + gamma * V_{t+1} * (1 - done_t) - V_t
        A_t     = delta_t + gamma * lam * (1 - done_t) * A_{t+1}

    where ``V_{t+1}`` is ``next_value`` for the last step and ``values[t+1]``
    otherwise.

    Args:
        rewards: Per-step rewards, oldest first.
        values: Value estimates V(s_t) for the same steps.
        next_value: Bootstrap value of the state after the last step; it is
            multiplied by zero when the last step is terminal.
        dones: Per-step termination flags (bool or 0/1).
        gamma: Discount factor.
        lam: GAE lambda.

    Returns:
        List of advantages, one per step, in the same order as ``rewards``
        (empty list for an empty rollout).
    """
    gae = 0
    advantages = []
    for t in reversed(range(len(rewards))):
        # (1 - done) stops bootstrapping and accumulation across episode ends.
        not_done = 1 - dones[t]
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        gae = delta + gamma * lam * not_done * gae
        advantages.append(gae)
        next_value = values[t]  # becomes V_{t+1} for the previous step
    # Built back-to-front; append + reverse is O(n) instead of O(n^2) insert(0, ...).
    advantages.reverse()
    return advantages


def ppo_train(env, total_timesteps=10000, batch_size=16, gamma=0.98, lam=0.95, clip_ratio=0.2, epochs=4):
    """Train a ``PPOAgent`` on ``env`` with clipped-surrogate PPO.

    Every ``batch_size`` environment steps the collected rollout is turned into
    GAE advantages and the agent is updated with ``epochs`` full-batch gradient
    steps (clipped policy loss - 0.05 * entropy + MSE value loss). The total
    reward collected over all training steps is printed at the end.

    Args:
        env: Environment with the classic gym API (``reset() -> obs`` and
            ``step(a) -> (obs, reward, done, info)``) and a discrete action
            space. Observations are fed to the network channel-first, i.e.
            ``obs.shape[0]`` is the channel axis, and are scaled by 1/255.
        total_timesteps: Number of environment steps to collect.
        batch_size: Rollout length between updates; a trailing partial rollout
            at the end of training is not used for an update.
        gamma: Discount factor.
        lam: GAE lambda.
        clip_ratio: PPO clipping epsilon.
        epochs: Gradient steps per rollout.

    Returns:
        The trained ``PPOAgent``.
    """
    action_dim = env.action_space.n
    obs = env.reset()

    # The network is fed obs.unsqueeze(0), so Conv2d treats obs.shape[0] as the
    # channel axis. Read it from the actual observation; observation_space.shape[2]
    # is the width for a (C, H, W) layout.
    in_channels = np.asarray(obs).shape[0]
    agent = PPOAgent(in_channels, action_dim)
    optimizer = optim.Adam(agent.parameters(), lr=5e-4)

    state_buffer, action_buffer, reward_buffer, done_buffer, value_buffer = [], [], [], [], []

    rewards = 0  # total reward over all training steps, printed at the end
    for _ in range(total_timesteps):
        state_buffer.append(obs)
        obs_tensor = torch.tensor(obs, dtype=torch.float32).unsqueeze(0) / 255.0  # scale pixels to [0, 1]

        with torch.no_grad():
            action_logits, value = agent(obs_tensor)
            dist = torch.distributions.Categorical(logits=action_logits)
            action = dist.sample()

        next_obs, reward, done, _ = env.step(action.item())
        rewards += reward

        # Store the transition.
        action_buffer.append(action.item())
        reward_buffer.append(reward)
        done_buffer.append(done)
        value_buffer.append(value.item())

        obs = next_obs if not done else env.reset()

        if len(state_buffer) >= batch_size:
            with torch.no_grad():
                # Bootstrap value of the state after the last step; compute_gae
                # masks it out with (1 - done) when that step ended the episode.
                next_obs_tensor = torch.tensor(next_obs, dtype=torch.float32).unsqueeze(0) / 255.0
                _, next_value = agent(next_obs_tensor)
            advantages = torch.tensor(
                compute_gae(reward_buffer, value_buffer, next_value.item(), done_buffer, gamma, lam),
                dtype=torch.float32,
            )
            values = torch.tensor(value_buffer, dtype=torch.float32)
            # Value targets must use the raw advantages: R_t = A_t + V(s_t).
            returns = values + advantages
            # Only the policy loss uses normalized advantages.
            norm_advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

            states_tensor = torch.as_tensor(np.asarray(state_buffer), dtype=torch.float32) / 255.0
            actions_tensor = torch.tensor(action_buffer, dtype=torch.long)

            # Log-probs under the policy that generated the rollout (no update
            # has happened since collection started), held fixed across epochs.
            with torch.no_grad():
                old_logits, _ = agent(states_tensor)
                old_log_probs = torch.distributions.Categorical(logits=old_logits).log_prob(actions_tensor)

            for _ in range(epochs):
                logits, values_pred = agent(states_tensor)
                dist = torch.distributions.Categorical(logits=logits)
                log_probs = dist.log_prob(actions_tensor)

                ratio = torch.exp(log_probs - old_log_probs)
                surr1 = ratio * norm_advantages
                surr2 = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio) * norm_advantages

                entropy_bonus = dist.entropy().mean()
                loss_policy = -torch.min(surr1, surr2).mean() - 0.05 * entropy_bonus

                # squeeze(-1) keeps shape (batch,) even when batch_size == 1.
                loss_value = F.mse_loss(values_pred.squeeze(-1), returns)

                optimizer.zero_grad()
                (loss_policy + loss_value).backward()
                optimizer.step()

            for buffer in (state_buffer, action_buffer, reward_buffer, done_buffer, value_buffer):
                buffer.clear()
    print(rewards)
    return agent


# 使用自定义环境进行训练并收集状态
def ppo_state_custom(env, total_timesteps=10000):
    """Train a PPO agent on ``env``, then roll it out to gather observations.

    ``ppo_train`` is run for ``total_timesteps`` steps. The trained policy then
    samples actions for another ``total_timesteps`` steps. The environment is
    reset whenever an episode ends.

    Args:
        env: Environment with the classic gym API used by ``ppo_train``.
        total_timesteps: Steps used both for training and for collection.

    Returns:
        NumPy array stacking the observation returned by every ``env.step``
        call during collection. These are post-step observations, including
        terminal ones; observations returned by ``env.reset()`` are not stored.
    """
    agent = ppo_train(env, total_timesteps=total_timesteps)

    collected = []
    current = env.reset()
    with torch.no_grad():
        for _ in range(total_timesteps):
            scaled = torch.tensor(current, dtype=torch.float32).unsqueeze(0) / 255.0
            logits, _ = agent(scaled)
            sampled = torch.distributions.Categorical(F.softmax(logits, dim=-1)).sample()

            current, _, done, _ = env.step(sampled.item())
            collected.append(current)

            if done:
                current = env.reset()

    state_buffer = np.array(collected)
    print(f"Collected {len(state_buffer)} states.")
    print(f"State buffer shape: {state_buffer.shape}")
    return state_buffer


# Train on the custom environment, then collect (s_t, a_t, s_{t+1}) transitions
def ppo_state_custom1(env, total_timesteps=10000):
    """Train a PPO agent on ``env``, then roll it out to gather transitions.

    ``ppo_train`` is run for ``total_timesteps`` steps. The trained policy then
    samples actions for another ``total_timesteps`` steps, and each step is
    recorded as an ``(s_t, a_t, s_{t+1})`` tuple.

    Args:
        env: Environment with the classic gym API used by ``ppo_train``.
        total_timesteps: Steps used both for training and for collection.

    Returns:
        List of ``(observation, action_index, next_observation)`` tuples. When
        an episode ends, ``next_observation`` is the terminal observation and
        the following tuple starts from the observation returned by
        ``env.reset()``.
    """
    agent = ppo_train(env, total_timesteps=total_timesteps)

    transitions = []
    current = env.reset()

    for _ in range(total_timesteps):
        with torch.no_grad():
            scaled = torch.tensor(current, dtype=torch.float32).unsqueeze(0) / 255.0
            logits, _ = agent(scaled)
            chosen = torch.distributions.Categorical(F.softmax(logits, dim=-1)).sample().item()

        following, _, done, _ = env.step(chosen)
        transitions.append((current, chosen, following))

        current = env.reset() if done else following

    print(f"Collected {len(transitions)} transitions.")
    return transitions





